{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lqlWcVtxYRoS"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow GNN Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BjwBK3YmYR1B"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AlRWj8VEcw_h"
      },
      "source": [
        "# Node Classification via TF-GNN\n",
        "This is a hands-on tutorial that explains how to run node classification on graph datasets. In particular, it covers:\n",
        "\n",
        "* How to load standard academic graph datasets (e.g., OGBN, Cora, Pubmed, Citeseer) into TF-GNN.\n",
        "* How to train either on the full graph or on subgraphs sampled on the fly.\n",
        "* How to choose among several simple GNN models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yfgS8ZHXYD5E"
      },
      "source": [
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/tutorials/log_2022/code_tutorial_1_tfgnn_single_machine.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/gnn/blob/main/examples/tutorials/log_2022/code_tutorial_1_tfgnn_single_machine.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sb8FaHmRUdMO"
      },
      "source": [
        "## pip-install TF-GNN"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lW3wYljhUfjr"
      },
      "outputs": [],
      "source": [
        "!pip install ogb\n",
        "!pip install tensorflow_gnn --pre"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4NRgg_kXd2bb"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Mk8DlnFOfCk"
      },
      "outputs": [],
      "source": [
        "# Standard imports\n",
        "import collections\n",
        "import functools\n",
        "import math\n",
        "from typing import List, Mapping\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_gnn as tfgnn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5cvU_ZMdPEO5"
      },
      "source": [
        "In addition to the standard imports, this tutorial uses the following TF-GNN modules."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8OLmGEs_O8EG"
      },
      "outputs": [],
      "source": [
        "# To use standard datasets that fit in memory.\n",
        "from tensorflow_gnn.experimental.in_memory import datasets\n",
        "\n",
        "# Implementations of example GNN models.\n",
        "from tensorflow_gnn.experimental.in_memory import models\n",
        "\n",
        "# Converts `tfgnn.GraphTensor` to (`tfgnn.GraphTensor`, `tf.Tensor`)\n",
        "# with second item containing task labels.\n",
        "from tensorflow_gnn.experimental.in_memory import reader_utils\n",
        "\n",
        "# For on-the-fly subgraph sampling.\n",
        "from tensorflow_gnn.sampler import sampling_spec_builder"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5aNCyg_fJFo"
      },
      "source": [
        "## Configure dataset, model hyperparameters, and training loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f1IVGap4Pvqh"
      },
      "outputs": [],
      "source": [
        "#@title Configurations\n",
        "#@markdown Dataset\n",
        "sampling_style = 'inmemory'  #@param ['nosampling','inmemory']\n",
        "dataset_name = 'ogbn-arxiv'  #@param ['ogbn-arxiv', 'ogbn-mag', 'cora', 'citeseer', 'pubmed']\n",
        "sampling_num_seed_nodes = 100  #@param int  # no-op if sampling_style == 'nosampling'\n",
        "# If set, the validation set is used as part of the training set. Some\n",
        "# datasets, such as OGB, allow this *after* hyperparameters are finalized.\n",
        "valid_as_train = True  #@param\n",
        "\n",
        "#@markdown Model\n",
        "\n",
        "# Models are implemented (by name) in tensorflow_gnn/experimental/in_memory/models.py\n",
        "model_name = 'GCN'  #@param ['GCN', 'Simple', 'JKNet', 'GATv2', 'ResidualGCN']\n",
        "\n",
        "# Passed to the model function (or constructor). Each model may accept\n",
        "# different parameters.\n",
        "# All models accept \"depth\" (i.e., number of GNN layers).\n",
        "model_kwargs = {\"dropout\":0.2, \"depth\":2} #@param\n",
        "\n",
        "# Converts graph to undirected.\n",
        "undirect_graph = True #@param\n",
        "\n",
        "# Only used for node sets without a 'feat' feature: every node of such a\n",
        "# node set is embedded in this many dimensions.\n",
        "embed_featureless_node_dim = 32  #@param int\n",
        "\n",
        "\n",
        "#@markdown Training loop\n",
        "\n",
        "# Number of training epochs.\n",
        "epochs = 50  #@param int\n",
        "\n",
        "# Run validation every this many epochs.\n",
        "eval_every = 1  #@param int\n",
        "\n",
        "# Learning rate for the Adam optimizer.\n",
        "learning_rate = 1e-2 #@param float\n",
        "\n",
        "# L2 regularization only applies to kernel weights (not biases).\n",
        "l2_regularization = 1e-5\n",
        "\n",
        "print(f'training {model_name} (with sampling={sampling_style}, num_seed_nodes={sampling_num_seed_nodes}) on dataset {dataset_name}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ntJtJLAFSlN5"
      },
      "source": [
        "## Dataset\n",
        "\n",
        "The cell below loads the graph data as a `tf.data.Dataset`, so that it can be passed to the model for training. It creates the following variables:\n",
        "\n",
        "* `train_ds`, `valid_ds`, `test_ds`: `tf.data.Dataset` instances yielding `(GraphTensor, labels)` pairs. When training on the full graph, each dataset contains a single element: the full graph, without any stochasticity. When training with sampling, each dataset yields sampled subgraphs.\n",
        "\n",
        "* `graph_spec`: the spec of a `GraphTensor`, used to create the (symbolic) Keras input layer from which the model is built."
      ]
    },
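    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When `sampling_style == 'inmemory'`, the cell below builds its sampling spec with `make_sampling_spec_with_dfs`: starting from the labeled node set, it samples `fanout` neighbors along *every* edge set leaving the current node set, recursively up to `depth` hops. The framework-free sketch below (plain Python, not TF-GNN code) mirrors that recursion on a toy schema to show which sampling steps it produces. The schema, its edge-set names, and the helper `enumerate_sampling_steps` are hypothetical, and the per-step node bound assumes each frontier node draws `fanout` neighbors.\n",
        "\n",
        "```python\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "# Hypothetical toy schema: edge set name mapped to (source, target) node sets.\n",
        "TOY_EDGE_SETS: Dict[str, Tuple[str, str]] = {\n",
        "    'cites': ('paper', 'paper'),\n",
        "    'written_by': ('paper', 'author'),\n",
        "    'writes': ('author', 'paper'),\n",
        "}\n",
        "\n",
        "\n",
        "def enumerate_sampling_steps(\n",
        "    edge_sets: Dict[str, Tuple[str, str]], seed_node_set: str,\n",
        "    fanout: int = 5, depth: int = 2) -\u003e List[Tuple[Tuple[str, ...], int]]:\n",
        "  # Group edge set names by source node set (like edge_set_names_by_source).\n",
        "  by_source: Dict[str, List[str]] = {}\n",
        "  for name, (source, _) in edge_sets.items():\n",
        "    by_source.setdefault(source, []).append(name)\n",
        "\n",
        "  steps: List[Tuple[Tuple[str, ...], int]] = []\n",
        "\n",
        "  def recurse(node_set: str, path: Tuple[str, ...], remaining: int) -\u003e None:\n",
        "    if remaining == 0:\n",
        "      return\n",
        "    for name in by_source.get(node_set, []):\n",
        "      new_path = path + (name,)\n",
        "      # Each step draws `fanout` neighbors per frontier node, so a path of\n",
        "      # length k reaches at most fanout**k nodes per seed node.\n",
        "      steps.append((new_path, fanout ** len(new_path)))\n",
        "      recurse(edge_sets[name][1], new_path, remaining - 1)\n",
        "\n",
        "  recurse(seed_node_set, (), depth)\n",
        "  return steps\n",
        "\n",
        "\n",
        "for path, max_nodes in enumerate_sampling_steps(TOY_EDGE_SETS, 'paper'):\n",
        "  print(' / '.join(path), max_nodes)\n",
        "```\n",
        "\n",
        "On this toy schema, seeding from `'paper'` with `depth=2` yields five steps: two one-hop steps (`cites`, `written_by`) and three two-hop steps that continue from wherever each one-hop step lands. The number of steps grows with both the number of edge sets per node set and `depth`, so specs for heterogeneous graphs can become large quickly."
      ]
    },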
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NEowzJ8nQxhH"
      },
      "outputs": [],
      "source": [
        "# Load our input graph in TF-GNN format, which\n",
        "# we use to create tf.data.Dataset.\n",
        "graph_data = datasets.get_in_memory_graph_data(dataset_name)\n",
        "assert isinstance(graph_data, datasets.NodeClassificationGraphData)\n",
        "num_classes = graph_data.num_classes()\n",
        "print('loaded %s in graph_data with num_classes=%i' % (dataset_name, num_classes))\n",
        "\n",
        "\n",
        "\n",
        "# Prepare datasets: train, test, validate\n",
        "train_split = ['train']  # Extended below if valid_as_train.\n",
        "valid_split = 'valid'    # Replaced below if valid_as_train.\n",
        "test_split = 'test'      # Constant.\n",
        "if valid_as_train:\n",
        "  train_split.append('validation')\n",
        "  valid_split = 'test'\n",
        "\n",
        "graph_data = (graph_data\n",
        "              .with_labels_as_features(True)\n",
        "              .with_self_loops(True)\n",
        "              .with_undirected_edges(undirect_graph))\n",
        "print(graph_data.edge_sets())\n",
        "\n",
        "# Helpers for in-memory sampling\n",
        "# Returns map: node set name -\u003e list of edge set names outgoing from node set.\n",
        "def edge_set_names_by_source(\n",
        "    graph: tfgnn.GraphSchema\n",
        "    ) -\u003e Mapping[tfgnn.NodeSetName, List[tfgnn.EdgeSetName]]:\n",
        "  results = collections.defaultdict(list)\n",
        "  for edge_set_name, edge_set_schema in graph.edge_sets.items():\n",
        "    results[edge_set_schema.source].append(edge_set_name)\n",
        "  return results\n",
        "\n",
        "\n",
        "# Creates a SamplingSpec proto that instructs the sampler to traverse *every*\n",
        "# edge set reachable from the labeled node set, recursively up to `depth` hops,\n",
        "# sampling `fanout` neighbor nodes from *each* edge set.\n",
        "def make_sampling_spec_with_dfs(\n",
        "    graph_data: datasets.NodeClassificationGraphData,\n",
        "    fanout=5, depth=2):\n",
        "  graph_schema = graph_data.graph_schema()\n",
        "  edge_sets_by_src_node_set = edge_set_names_by_source(graph_schema)\n",
        "  spec_builder = sampling_spec_builder.SamplingSpecBuilder(\n",
        "      graph_schema,\n",
        "      default_strategy=sampling_spec_builder.SamplingStrategy.RANDOM_UNIFORM)\n",
        "  spec_builder = spec_builder.seed(graph_data.labeled_nodeset)\n",
        "\n",
        "  def _recursively_sample_all_edge_sets(\n",
        "      cur_node_set_name, sampling_step, remaining_depth):\n",
        "    if remaining_depth == 0:\n",
        "      return\n",
        "    for edge_set_name in edge_sets_by_src_node_set[cur_node_set_name]:\n",
        "      edge_set_schema = graph_schema.edge_sets[edge_set_name]\n",
        "      _recursively_sample_all_edge_sets(\n",
        "          edge_set_schema.target,\n",
        "          sampling_step.sample(fanout, edge_set_name),\n",
        "          remaining_depth - 1)\n",
        "\n",
        "  _recursively_sample_all_edge_sets(\n",
        "      graph_data.labeled_nodeset, spec_builder, depth)\n",
        "\n",
        "  return spec_builder.build()\n",
        "\n",
        "# For each edge (i, j), adds edge (j, i), regardless if it already exists.\n",
        "# This function can be folded into TF-GNN proper, but left here for demonstration.\n",
        "def undirect_subgraph(graph_tensor: tfgnn.GraphTensor, label) -\u003e (tfgnn.GraphTensor, tf.Tensor):\n",
        "  edge_sets = {}\n",
        "  for es_name, es in graph_tensor.edge_sets.items():\n",
        "    if es.adjacency.source_name == es.adjacency.target_name:\n",
        "      src, tgt = es.adjacency.source, es.adjacency.target\n",
        "      node_set_name = es.adjacency.source_name\n",
        "      new_src = tf.concat([src, tgt], axis=0)\n",
        "      new_tgt = tf.concat([tgt, src], axis=0)\n",
        "      edge_sets[es_name] = tfgnn.EdgeSet.from_fields(\n",
        "          sizes=es.sizes*2,\n",
        "          adjacency=tfgnn.Adjacency.from_indices(\n",
        "              source=(node_set_name, new_src),\n",
        "              target=(node_set_name, new_tgt)))\n",
        "    else:\n",
        "      edge_sets[es_name] = es\n",
        "\n",
        "  graph_tensor = tfgnn.GraphTensor.from_pieces(\n",
        "      context=graph_tensor.context,\n",
        "      node_sets=graph_tensor.node_sets,\n",
        "      edge_sets=edge_sets)\n",
        "\n",
        "  return graph_tensor, label\n",
        "\n",
        "\n",
        "def as_dataset(graph_data, pop_labels_from_graph: bool = True):\n",
        "  graph_data = graph_data.with_labels_as_features(True)\n",
        "  dataset = tf.data.Dataset.from_tensors(graph_data.as_graph_tensor())\n",
        "\n",
        "  if pop_labels_from_graph:\n",
        "    # Moves labels out of the graph, so the dataset yields (graph, labels).\n",
        "    num_classes = graph_data.num_classes()\n",
        "    dataset = dataset.map(\n",
        "        functools.partial(reader_utils.pop_labels_from_graph, num_classes))\n",
        "\n",
        "  return dataset\n",
        "\n",
        "\n",
        "if sampling_style == 'nosampling':\n",
        "  train_ds = as_dataset(graph_data.with_split(train_split))\n",
        "  valid_ds = as_dataset(graph_data.with_split(valid_split))\n",
        "  test_ds = as_dataset(graph_data.with_split(test_split))\n",
        "\n",
        "  total_train_steps = epochs  # Every step includes all graph nodes \u0026 edges.\n",
        "  train_steps_per_epoch = 1\n",
        "  validation_steps = 1  # One validation step evaluates all validation nodes.\n",
        "  total_test_steps = 1\n",
        "elif sampling_style == 'inmemory':\n",
        "  sampling_spec = make_sampling_spec_with_dfs(graph_data)\n",
        "\n",
        "  from tensorflow_gnn.experimental.in_memory import int_arithmetic_sampler as ia_sampler\n",
        "  sampling_args = dict(\n",
        "      sampling_spec=sampling_spec,\n",
        "      num_seed_nodes=sampling_num_seed_nodes,\n",
        "      sampling_mode=ia_sampler.EdgeSampling.WITH_REPLACEMENT)\n",
        "  sampler_class = ia_sampler.NodeClassificationGraphSampler\n",
        "\n",
        "  train_ds = sampler_class(graph_data.with_split(train_split)).as_dataset(\n",
        "      **sampling_args).map(undirect_subgraph)\n",
        "  valid_ds = sampler_class(graph_data.with_split(valid_split)).as_dataset(\n",
        "      **sampling_args).map(undirect_subgraph)\n",
        "  test_ds = sampler_class(graph_data.with_split(test_split)).as_dataset(\n",
        "      **sampling_args).map(undirect_subgraph)\n",
        "  size_train = graph_data.node_split().train.shape[0]\n",
        "  train_steps_per_epoch = math.ceil(size_train / sampling_num_seed_nodes)\n",
        "  validation_steps = 10  # Evaluate small subset of validation.\n",
        "  total_test_steps = 100\n",
        "  print('\\nsampling_spec:\\n', sampling_spec)\n",
        "\n",
        "\n",
        "# Make sure we can take an example of the dataset.\n",
        "example = next(iter(train_ds.take(1)))\n",
        "graph_tensor, labels = example\n",
        "\n",
        "# Crucially, above `graph_tensor` will be used to create `graph_spec` that will\n",
        "# be used as a placeholder to instantiate the model.\n",
        "graph_spec = graph_tensor.spec.relax(num_edges=True, num_nodes=True)"
      ]
    },
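    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`undirect_subgraph` only touches edge sets whose source and target node sets coincide: it appends a reversed copy of every edge, without deduplicating, and rebuilds the edge set from `sizes` and `adjacency` alone (so any edge features are not carried over). The index manipulation itself is easy to see on plain Python lists; `undirect_edges` below is a hypothetical helper for illustration, not part of TF-GNN.\n",
        "\n",
        "```python\n",
        "from typing import List, Tuple\n",
        "\n",
        "\n",
        "def undirect_edges(source: List[int],\n",
        "                   target: List[int]) -\u003e Tuple[List[int], List[int]]:\n",
        "  # Same as tf.concat([src, tgt]) and tf.concat([tgt, src]) in\n",
        "  # undirect_subgraph: every edge (i, j) gains a reversed copy (j, i),\n",
        "  # even when (j, i) is already in the edge set.\n",
        "  new_source = source + target\n",
        "  new_target = target + source\n",
        "  return new_source, new_target\n",
        "\n",
        "\n",
        "src, tgt = undirect_edges([0, 1, 2], [1, 2, 1])\n",
        "print(list(zip(src, tgt)))\n",
        "```\n",
        "\n",
        "Here the edges (1, 2) and (2, 1) were already reciprocal, so each now appears twice. The edge count exactly doubles, which is why the rebuilt edge set uses `sizes=es.sizes*2`."
      ]
    },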
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D1HQmzEBVyCQ"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzpKuzckRpYn"
      },
      "outputs": [],
      "source": [
        "# Customize model.\n",
        "model_prefers_undirected, model = models.make_model_by_name(\n",
        "      model_name, num_classes, l2_coefficient=l2_regularization,\n",
        "      model_kwargs=model_kwargs)\n",
        "\n",
        "node_counts = graph_data.node_counts()\n",
        "\n",
        "# Input pre-processing layer. Can be trainable.\n",
        "def set_init_node_features(node_set, node_set_name):\n",
        "  # If there is a feature already, return it.\n",
        "  if 'feat' in node_set.features:\n",
        "    return node_set['feat']\n",
        "\n",
        "  # Otherwise, we can embed the node_set.\n",
        "  if embed_featureless_node_dim \u003e 0:\n",
        "    embedding_layer = tf.keras.layers.Embedding(\n",
        "            node_counts[node_set_name],\n",
        "            embed_featureless_node_dim)\n",
        "\n",
        "# Input is placeholder\n",
        "input_graph = tf.keras.layers.Input(type_spec=graph_spec)\n",
        "graph = input_graph\n",
        "\n",
        "# Transformations.\n",
        "graph = tfgnn.keras.layers.MapFeatures(  # Pre-processing.\n",
        "    node_sets_fn=set_init_node_features)(graph)\n",
        "graph = model(graph)  # GNN.\n",
        "\n",
        "# Readout tf.Tensor.\n",
        "h = reader_utils.readout_seed_node_features(\n",
        "    graph, node_set_name=graph_data.labeled_nodeset)\n",
        "\n",
        "# Post-processing (optional)\n",
        "# h = tf.keras.layers.Dense(num_classes)(h)\n",
        "\n",
        "# Capture the computation graph input_graph -\u003e h as a `tf.keras.Model`.\n",
        "keras_model = tf.keras.Model(inputs=input_graph, outputs=h)"
      ]
    },
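    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For node sets without a `'feat'` feature, `set_init_node_features` creates a `tf.keras.layers.Embedding` table with `node_counts[node_set_name]` rows, i.e., one row per node of the *full* graph. In a sampled subgraph, node positions are local, so the lookup index must be the global node id (the subgraph's `'#id'` feature, which the test cell also uses to fetch labels). A framework-free sketch of that row lookup, with a hypothetical four-row table:\n",
        "\n",
        "```python\n",
        "from typing import List\n",
        "\n",
        "# Hypothetical embedding table: one row per node of the full graph\n",
        "# (4 nodes, 2 dimensions). In the notebook, a trainable\n",
        "# tf.keras.layers.Embedding plays this role.\n",
        "TABLE: List[List[float]] = [\n",
        "    [0.0, 0.1],\n",
        "    [1.0, 1.1],\n",
        "    [2.0, 2.1],\n",
        "    [3.0, 3.1],\n",
        "]\n",
        "\n",
        "\n",
        "def embed(global_ids: List[int]) -\u003e List[List[float]]:\n",
        "  # An embedding lookup is a row gather by (global) node id.\n",
        "  return [TABLE[i] for i in global_ids]\n",
        "\n",
        "\n",
        "# A sampled subgraph holding full-graph nodes 3 and 1, in that order:\n",
        "# subgraph position 0 must read row 3, not row 0.\n",
        "subgraph_global_ids = [3, 1]\n",
        "print(embed(subgraph_global_ids))\n",
        "```"
      ]
    },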
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1yTwpLKoVsfN"
      },
      "source": [
        "## Optimization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P84chPiQa9Gy"
      },
      "outputs": [],
      "source": [
        "# Optimizer Function\n",
        "# optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)\n",
        "\n",
        "opt = tf.keras.optimizers.Adam(learning_rate)\n",
        "loss = tf.keras.losses.CategoricalCrossentropy(\n",
        "    from_logits=True, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)\n",
        "\n",
        "keras_model.compile(opt, loss=loss, metrics=['acc'])"
      ]
    },
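    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model presumably outputs one logit per class for each seed node, and `tf.keras.losses.CategoricalCrossentropy(from_logits=True)` applies a softmax internally before the cross-entropy; `SUM_OVER_BATCH_SIZE` then divides the summed per-node losses by the number of seed nodes. Categorical (not sparse) cross-entropy expects one-hot labels, presumably what `pop_labels_from_graph` produces since it receives `num_classes`. A plain-Python sketch of the same computation, with hypothetical logits and labels and hypothetical helper names:\n",
        "\n",
        "```python\n",
        "import math\n",
        "from typing import List\n",
        "\n",
        "\n",
        "def softmax_cross_entropy(logits: List[float], one_hot: List[float]) -\u003e float:\n",
        "  # Numerically stable log-softmax: subtract the max logit before exp().\n",
        "  m = max(logits)\n",
        "  log_z = m + math.log(sum(math.exp(x - m) for x in logits))\n",
        "  return -sum(y * (x - log_z) for x, y in zip(logits, one_hot))\n",
        "\n",
        "\n",
        "def sum_over_batch_size(batch_logits: List[List[float]],\n",
        "                        batch_labels: List[List[float]]) -\u003e float:\n",
        "  # SUM_OVER_BATCH_SIZE: sum of per-example losses divided by batch size.\n",
        "  losses = [softmax_cross_entropy(lg, y)\n",
        "            for lg, y in zip(batch_logits, batch_labels)]\n",
        "  return sum(losses) / len(losses)\n",
        "\n",
        "\n",
        "# Two seed nodes, three classes.\n",
        "print(sum_over_batch_size([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]],\n",
        "                          [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]))\n",
        "```\n",
        "\n",
        "The confident, correct first node costs about 0.24, while the uniform guess over three classes for the second node costs ln 3 ≈ 1.10."
      ]
    },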
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zm1-ulRabRWL"
      },
      "outputs": [],
      "source": [
        "keras_model.fit(\n",
        "      train_ds,\n",
        "      epochs=epochs,\n",
        "      steps_per_epoch=train_steps_per_epoch,\n",
        "      validation_data=valid_ds.repeat(),\n",
        "      validation_steps=validation_steps,\n",
        "      validation_freq=eval_every)"
      ]
    },
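    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The step counts passed to `fit` come from the dataset cell. Without sampling, every step sees the whole graph, so `train_steps_per_epoch = 1`; with sampling, `train_steps_per_epoch = math.ceil(size_train / sampling_num_seed_nodes)`, i.e., enough batches of seed nodes to match the number of training nodes. The helper below (hypothetical, with a hypothetical training-set size) reproduces this arithmetic and the resulting total number of gradient steps over `epochs` epochs:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def steps_per_epoch(num_train_nodes: int, num_seed_nodes: int,\n",
        "                    sampling: bool) -\u003e int:\n",
        "  # Full-graph training: one step already covers all nodes and edges.\n",
        "  if not sampling:\n",
        "    return 1\n",
        "  # Sampling: enough batches of seed nodes to match the training-set size.\n",
        "  return math.ceil(num_train_nodes / num_seed_nodes)\n",
        "\n",
        "\n",
        "epochs = 50\n",
        "# Hypothetical training set of 1,050 labeled nodes, 100 seed nodes per batch.\n",
        "per_epoch = steps_per_epoch(1050, 100, sampling=True)\n",
        "total_steps = per_epoch * epochs\n",
        "print(per_epoch, total_steps)\n",
        "```\n",
        "\n",
        "Raising `sampling_num_seed_nodes` shrinks the number of steps per epoch proportionally, at the cost of larger sampled subgraphs per step. With `validation_freq=eval_every`, validation runs `validation_steps` batches after every `eval_every`-th epoch."
      ]
    },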
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4gZ1q5NBv4FL"
      },
      "source": [
        "## Test model performance"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kkHDaxVWcNIC"
      },
      "outputs": [],
      "source": [
        "total_correct = 0\n",
        "total_size = 0\n",
        "# `total_test_steps` is set in the dataset cell (1 without sampling).\n",
        "test_labels = graph_data.test_labels()\n",
        "\n",
        "seed_feat_name = 'seed_nodes.' + graph_data.labeled_nodeset\n",
        "\n",
        "for test_step, (graph, unused_labels) in enumerate(test_ds):\n",
        "  if test_step \u003e= total_test_steps:\n",
        "    break\n",
        "  out = keras_model(graph)\n",
        "\n",
        "  # Map seed-node positions in the (sub)graph to global node ids, then\n",
        "  # look up their ground-truth test labels.\n",
        "  seed_pos = graph.context[seed_feat_name]\n",
        "  id_feat = graph.node_sets[graph_data.labeled_nodeset]['#id']\n",
        "  seed_ids = tf.gather(id_feat, seed_pos)\n",
        "  labels = tf.gather(test_labels, seed_ids)\n",
        "  pred_y = tf.cast(tf.argmax(out, -1), labels.dtype)\n",
        "\n",
        "  int_is_correct_label = tf.cast(pred_y == labels, tf.int64)\n",
        "  total_correct += tf.reduce_sum(int_is_correct_label)\n",
        "  total_size += tf.size(int_is_correct_label, out_type=tf.int64)\n",
        "\n",
        "print('Test accuracy = %f' % (\n",
        "    tf.cast(total_correct, tf.float32) / tf.cast(total_size, tf.float32)))"
      ]
    },
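    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each test batch stores the *positions* of its seed nodes in the context feature `'seed_nodes.' + graph_data.labeled_nodeset`, and the labeled node set's `'#id'` feature maps positions to global node ids. The test loop therefore chains two gathers to fetch ground-truth labels from `graph_data.test_labels()`, then counts correct predictions across at most `total_test_steps` batches. A plain-Python sketch of that bookkeeping, with hypothetical data and hypothetical helpers `batch_correct` and `accuracy`:\n",
        "\n",
        "```python\n",
        "from typing import List, Sequence, Tuple\n",
        "\n",
        "Batch = Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]\n",
        "\n",
        "\n",
        "def batch_correct(seed_positions: Sequence[int],\n",
        "                  node_global_ids: Sequence[int],\n",
        "                  all_labels: Sequence[int],\n",
        "                  predictions: Sequence[int]) -\u003e Tuple[int, int]:\n",
        "  # Subgraph position, then global node id, then ground-truth label.\n",
        "  seed_ids = [node_global_ids[p] for p in seed_positions]\n",
        "  labels = [all_labels[i] for i in seed_ids]\n",
        "  correct = sum(int(p == y) for p, y in zip(predictions, labels))\n",
        "  return correct, len(labels)\n",
        "\n",
        "\n",
        "def accuracy(batches: List[Batch], max_steps: int) -\u003e float:\n",
        "  total_correct = total_size = 0\n",
        "  for step, batch in enumerate(batches):\n",
        "    if step \u003e= max_steps:  # Evaluate at most max_steps batches.\n",
        "      break\n",
        "    correct, size = batch_correct(*batch)\n",
        "    total_correct += correct\n",
        "    total_size += size\n",
        "  return total_correct / total_size\n",
        "\n",
        "\n",
        "# Hypothetical global labels and two sampled test batches:\n",
        "# (seed positions, '#id' of subgraph nodes, global labels, predictions).\n",
        "ALL_LABELS = [0, 1, 2, 1, 0]\n",
        "BATCHES: List[Batch] = [\n",
        "    ([0, 2], [4, 2, 3], ALL_LABELS, [0, 2]),\n",
        "    ([0], [1, 0], ALL_LABELS, [1]),\n",
        "]\n",
        "print(accuracy(BATCHES, max_steps=100))\n",
        "```\n",
        "\n",
        "Counting correct predictions and sizes separately and dividing once at the end weights every seed node equally, even when batches differ in size; averaging per-batch accuracies would not."
      ]
    },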
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fSNSsMi1WMDf"
      },
      "source": [
        "## Your assignment\n",
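        "Make a simple change (number of layers via `depth` in `model_kwargs`, regularization, hidden dimensions) and observe how accuracy changes.\n",
        "\n",
        "Try a residual model, e.g., set `model_name = 'ResidualGCN'`."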
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_g1EULpWw1i"
      },
      "source": [
        "## Bonus: Import your custom dataset\n",
        "This works for datasets that fit in memory. See how the built-in datasets are implemented in:\n",
        "\n",
        "https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/experimental/in_memory/datasets.py"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "lqlWcVtxYRoS"
      ],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1jVcbsbCQuWAqLoYaZKSgS18kR5d2LbnJ",
          "timestamp": 1669237441923
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
